# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from functools import partial

import numpy as np

from ..defaults import _handle_default
from ..fixes import _safe_svd
from ..utils import eigh, logger, sqrtm_sym, warn

# For the reference implementation of eLORETA (force_equal=False),
# 0 < loose <= 1 all produce solutions that are (more or less)
# the same as free orientation (loose=1) and quite different from
# loose=0 (fixed). If we do force_equal=True, we get a visibly smooth
# transition from 0->1. This is probably because this mode behaves more like
# sLORETA and dSPM in that it weights each orientation for a given source
# uniformly (which is not the case for the reference eLORETA implementation).
#
# If we *reapply the orientation prior* after each eLORETA iteration,
# we can preserve the smooth transition without requiring force_equal=True,
# which is probably more representative of what eLORETA should do. But this
# does not produce results that pass the eye test.


def _compute_eloreta(inv, lambda2, options):
    """Compute the eLORETA solution and store it in ``inv`` (in place).

    The gain matrix is reassembled from the SVD stored in ``inv``, the
    eLORETA source weights ``R`` are fitted iteratively, and the stored SVD
    is then replaced by the SVD of the weighted gain matrix.

    Parameters
    ----------
    inv : dict
        Inverse operator. Modified in place: ``inv["sing"]``,
        ``inv["reginv"]``, ``inv["eigen_leads"]["data"]`` and
        ``inv["eigen_fields"]["data"]`` are replaced,
        ``inv["eigen_leads_weighted"]`` is set to True, and
        ``inv["source_cov"]["data"]`` is filled with NaN.
    lambda2 : float
        Regularization parameter. It is added to the eigenvalues of
        ``G @ R @ G.T`` during fitting and passed to ``_compute_reginv``.
    options : dict
        eLORETA options, completed with defaults via ``_handle_default``.
        The keys ``"eps"`` (convergence tolerance on the relative change of
        ``R``), ``"max_iter"`` and ``"force_equal"`` are used.

    Raises
    ------
    RuntimeError
        If ``inv["eigen_leads_weighted"]`` is already True.
    """
    from .inverse import _compute_reginv, compute_rank_inverse

    options = _handle_default("eloreta_options", options)
    eps, max_iter = options["eps"], options["max_iter"]
    force_equal = bool(options["force_equal"])  # None means False

    # Reassemble the gain matrix (should be fast enough)
    if inv["eigen_leads_weighted"]:
        # We can probably relax this if we ever need to
        raise RuntimeError("eLORETA cannot be computed with weighted eigen leads")
    G = np.dot(
        inv["eigen_fields"]["data"].T * inv["sing"], inv["eigen_leads"]["data"].T
    )
    # These SVD pieces are reassigned at the end from the new SVD.
    del inv["eigen_leads"]["data"]
    del inv["eigen_fields"]["data"]
    del inv["sing"]
    G = G.astype(np.float64)
    n_nzero = compute_rank_inverse(inv)
    G /= np.sqrt(inv["source_cov"]["data"])
    # restore orientation prior
    source_std = np.ones(G.shape[1])
    if inv["orient_prior"] is not None:
        source_std *= np.sqrt(inv["orient_prior"]["data"])
    G *= source_std
    # We do not multiply by the depth prior, as eLORETA should compensate for
    # depth bias.
    n_src = inv["nsource"]
    n_orient = G.shape[1] // n_src
    assert n_orient in (1, 3)
    logger.info("    Computing optimized source covariance (eLORETA)...")
    if n_orient == 3:
        logger.info(
            f"        Using {'uniform' if force_equal else 'independent'} "
            "orientation weights"
        )
    # (n_src, n_orient, n_chan) stack of G, or None for fixed orientation
    G_3 = _get_G_3(G, n_orient)
    if n_orient != 1 and not force_equal:
        # Outer product: one 3x3 prior block per source
        R_prior = source_std.reshape(n_src, 1, 3) * source_std.reshape(n_src, 3, 1)
    else:
        R_prior = source_std**2

    # The following was adapted under BSD license by permission of Guido Nolte
    if force_equal or n_orient == 1:
        # Diagonal weights: one scalar per source column
        R_shape = (n_src * n_orient,)
        R = np.ones(R_shape)
    else:
        # Block-diagonal weights: one 3x3 matrix per source
        R_shape = (n_src, n_orient, n_orient)
        R = np.empty(R_shape)
        R[:] = np.eye(n_orient)[np.newaxis]
    R *= R_prior
    _this_normalize_R = partial(
        _normalize_R,
        n_nzero=n_nzero,
        force_equal=force_equal,
        n_src=n_src,
        n_orient=n_orient,
    )
    # Note: _normalize_R rescales R in place as well as returning G @ R @ G.T
    G_R_Gt = _this_normalize_R(G, R, G_3)
    extra = " (this may take a while)" if n_orient == 3 else ""
    logger.info(f"        Fitting up to {max_iter} iterations{extra}...")
    for kk in range(max_iter):
        # 1. Compute inverse of the weights (stabilized) and C
        s, u = eigh(G_R_Gt)
        s = abs(s)
        # keep only the n_nzero largest-magnitude eigenpairs
        sidx = np.argsort(s)[::-1][:n_nzero]
        s, u = s[sidx], u[:, sidx]
        with np.errstate(invalid="ignore"):
            s = np.where(s > 0, 1 / (s + lambda2), 0)
        N = np.dot(u * s, u.T)
        del s

        # Update the weights
        R_last = R.copy()
        if n_orient == 1:
            R[:] = 1.0 / np.sqrt((np.dot(N, G) * G).sum(0))
        else:
            M = np.matmul(np.matmul(G_3, N[np.newaxis]), G_3.swapaxes(-2, -1))
            if force_equal:
                _, s = sqrtm_sym(M, inv=True)
                R[:] = np.repeat(1.0 / np.mean(s, axis=-1), 3)
            else:
                R[:], _ = sqrtm_sym(M, inv=True)
        R *= R_prior  # reapply our prior, eLORETA undoes it
        G_R_Gt = _this_normalize_R(G, R, G_3)

        # Check for weight convergence (relative change of R)
        delta = np.linalg.norm(R.ravel() - R_last.ravel()) / np.linalg.norm(
            R_last.ravel()
        )
        logger.debug(
            f"            Iteration {kk + 1} / {max_iter} ...{extra} ({delta:0.1e})"
        )
        if delta < eps:
            # Report the 1-based iteration number, consistent with the
            # debug message above.
            logger.info(
                f"        Converged on iteration {kk + 1} ({delta:.2g} < {eps:.2g})"
            )
            break
    else:
        warn(f"eLORETA weight fitting did not converge (>= {eps})")
    del G_R_Gt
    logger.info("        Updating inverse with weighted eigen leads")
    G /= source_std  # undo our biasing
    G_3 = _get_G_3(G, n_orient)
    # Called for its in-place rescaling of R; the returned matrix is unused.
    _this_normalize_R(G, R, G_3)
    del G_3
    if n_orient == 1 or force_equal:
        R_sqrt = np.sqrt(R)
    else:
        R_sqrt = sqrtm_sym(R)[0]
    assert R_sqrt.shape == R_shape
    A = _R_sqrt_mult(G, R_sqrt)
    del R, G  # the rest will be done in terms of R_sqrt and A
    eigen_fields, sing, eigen_leads = _safe_svd(A, full_matrices=False)
    del A
    inv["sing"] = sing
    inv["reginv"] = _compute_reginv(inv, lambda2)
    inv["eigen_leads_weighted"] = True
    inv["eigen_leads"]["data"] = _R_sqrt_mult(eigen_leads, R_sqrt).T
    inv["eigen_fields"]["data"] = eigen_fields.T
    # XXX in theory we should set inv['source_cov'] properly.
    # For fixed ori (or free ori with force_equal=True), we can as these
    # are diagonal matrices. But for free ori without force_equal, it's a
    # block diagonal 3x3 and we have no efficient way of storing this (and
    # storing a covariance matrix with (20484 * 3) ** 2 elements is not going
    # to work. So let's just set to nan for now.
    # It's not used downstream anyway now that we set
    # eigen_leads_weighted = True.
    inv["source_cov"]["data"].fill(np.nan)
    logger.info("[done]")


def _normalize_R(G, R, G_3, n_nzero, force_equal, n_src, n_orient):
    """Normalize R so that lambda2 is consistent.

    ``R`` is divided in place by ``trace(G @ R @ G.T) / n_nzero``. The
    matching rescaled ``G @ R @ G.T`` is returned, and its trace equals
    ``n_nzero``.

    When ``n_orient == 1`` or ``force_equal`` is set, ``R`` is treated as a
    1D diagonal. Otherwise ``R`` is treated as a stack of 3x3 blocks that is
    multiplied against ``G_3``.
    """
    diagonal_weights = force_equal or n_orient == 1
    if diagonal_weights:
        # scale each row of G.T by its diagonal weight
        weighted_Gt = G.T * R[:, np.newaxis]
    else:
        # block-diagonal R applied per source, flattened back to 2D
        weighted_Gt = (R @ G_3).reshape(n_src * 3, -1)
    G_R_Gt = np.dot(G, weighted_Gt)
    scale = G_R_Gt.trace() / n_nzero
    R /= scale
    G_R_Gt /= scale
    return G_R_Gt


def _get_G_3(G, n_orient):
    """Regroup the columns of 2D ``G`` per source.

    Returns None when ``n_orient == 1``. Otherwise returns an array of shape
    ``(G.shape[1] // n_orient, n_orient, G.shape[0])``, where entry
    ``[i, k, c]`` is ``G[c, i * n_orient + k]``. The result may share memory
    with ``G``, because reshape/transpose return views when possible.
    """
    if n_orient != 1:
        n_rows = G.shape[0]
        grouped = G.reshape(n_rows, -1, n_orient)
        return np.transpose(grouped, (1, 2, 0))
    return None


def _R_sqrt_mult(other, R_sqrt):
    """Do other @ R ** 0.5.

    ``R_sqrt`` is either a 1D diagonal, which scales the columns of
    ``other``, or a stack of 3x3 blocks of shape ``(n_src, 3, 3)``. In the
    block case, column ``3 * i + j`` of the result is
    ``sum_k other[:, 3 * i + k] * R_sqrt[i, j, k]``. This equals
    ``other @ block_diag(R_sqrt)`` when each block is symmetric.
    """
    if R_sqrt.ndim != 1:
        assert R_sqrt.shape[1:3] == (3, 3)
        assert other.shape[1] == np.prod(R_sqrt.shape[:2])
        assert other.ndim == 2
        n_src, n_chan = R_sqrt.shape[0], other.shape[0]
        # (n_src, 3, n_chan): columns of `other` grouped per source
        per_source = other.reshape(n_chan, n_src, 3).transpose(1, 2, 0)
        weighted = np.matmul(R_sqrt, per_source)
        return weighted.reshape(n_src * 3, n_chan).T
    # diagonal case: broadcast across rows, scaling each column
    assert other.shape[1] == R_sqrt.size
    return R_sqrt * other
